{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "import glob\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "from os.path import join, basename, isdir\n",
    "from os import listdir, makedirs\n",
    "import shutil as sh\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['figure.figsize'] = 9, 6"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sort Trials"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "HOME = os.environ[\"HOME\"]\n",
    "print(HOME)\n",
    "os.chdir(join(HOME, 'tmp/pretrained-runs/'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_contents(d, dtype=str):\n",
    "    with open(d, \"r\") as f:\n",
    "        return dtype(f.readline())\n",
    "\n",
    "def get_source_info(d):\n",
    "    dsource = get_contents(join(d, \"weights-source\"), str)\n",
    "    dsource = dsource.replace(\"/root/trained-models\", join(HOME, \"tmp/final-runs\"))\n",
    "    dsource += d[-5:]\n",
    "    duration = get_contents(join(dsource, \"duration\"), float)\n",
    "    gain = get_contents(join(d, \"tx-gain\"), float)\n",
    "    mode = \"nc\" if \"classic\" in dsource else \"nn\"\n",
    "    shared = \"shared\" if \"shared\" in dsource else \"private\"\n",
    "    return duration, gain, mode, shared"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "get_source_info(\"pretrained-neural-neural-1655856184-srn2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "dirs = list(filter(lambda d: isdir(d), listdir('.')))\n",
    "dirs = list(filter(lambda d: 'neural' in d, dirs))\n",
    "durations = {}\n",
    "durations['private'] = {'nc': defaultdict(list), 'nn': defaultdict(list)}\n",
    "durations['shared'] = {'nc': defaultdict(list), 'nn': defaultdict(list)}\n",
    "for d in dirs:\n",
    "    t, gain, agents, preamble = get_source_info(d)\n",
    "    durations[preamble][agents][(t, gain)].append(d)\n",
    "for k in durations[\"shared\"][\"nc\"]:\n",
    "    durations[\"shared\"][\"nc\"][k].sort()\n",
    "for k in durations[\"shared\"][\"nn\"]:\n",
    "    durations[\"shared\"][\"nn\"][k].sort()\n",
    "for k in durations[\"private\"][\"nc\"]:\n",
    "    durations[\"private\"][\"nc\"][k].sort()\n",
    "for k in durations[\"private\"][\"nn\"]:\n",
    "    durations[\"private\"][\"nn\"][k].sort()\n",
    "# print(durations)\n",
    "# Remove shorter tests\n",
    "keys = list(durations[\"shared\"][\"nc\"].keys())\n",
    "for k in keys:\n",
    "    if k[0] < 120:\n",
    "        durations[\"shared\"][\"nc\"].pop(k)\n",
    "keys = list(durations[\"shared\"][\"nn\"].keys())\n",
    "for k in keys:\n",
    "    if k[0] < 300:\n",
    "        durations[\"shared\"][\"nn\"].pop(k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Trial Outcomes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def success_rate(ms): \n",
    "    if ms.size <= 40:\n",
    "        return sum(ms < 0.1) / 40.\n",
    "    else:\n",
    "        return sum(ms < 0.1) / 66.\n",
    "\n",
    "def do_runtime(dirs, sharedstr, agentstr, timestr):\n",
    "    mins = []\n",
    "    means = []\n",
    "    maxes = []\n",
    "    for d in dirs:\n",
    "        try:\n",
    "            with open(join(d, 'results'), 'r') as f:\n",
    "                lcount, lmean, lmin, lmax, _ = f.readlines()\n",
    "                mins.append(float(lmin.strip().split()[-1]))\n",
    "                means.append(float(lmean.strip().split()[-1]))\n",
    "                maxes.append(float(lmax.strip().split()[-1]))\n",
    "        except:\n",
    "            pass\n",
    "\n",
    "    plt.stem(means)\n",
    "    plt.yscale('log')\n",
    "    plt.title(\"{} preamble {}, {} seconds trials\".format(\n",
    "                sharedstr, agentstr, timestr))\n",
    "    plt.show()\n",
    "    print(\"Success rate\", success_rate(np.array(means)))\n",
    "    return mins, means, maxes\n",
    "\n",
    "for sharedstr in (\"shared\", \"private\"):\n",
    "    for agentstr in (\"nn\", \"nc\"):\n",
    "        print(' '.join([\"=====\", sharedstr.upper(), \n",
    "                       agentstr.upper(), \"=====\"]))\n",
    "        for time in durations[sharedstr][agentstr]:\n",
    "            do_runtime(durations[sharedstr][agentstr][time],\n",
    "                       sharedstr, agentstr, time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(durations['private']['nn'][(600, 16)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SNR vs BER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def success_rate(ms): \n",
    "    if ms.size <= 40:\n",
    "        return sum(ms < 0.1) / 40.\n",
    "    else:\n",
    "        return sum(ms < 0.1) / 66.\n",
    "\n",
    "def do_snr_ber(dirs, sharedstr, agentstr):\n",
    "    gains = []\n",
    "    ber10 = []\n",
    "    ber50 = []\n",
    "    ber90 = []\n",
    "#     baseline = 10. ** -np.arange(1, 6) \n",
    "    baseline = [1.04085880e-01, 9.88240808e-03, 7.21992638e-04, 9.59187110e-05, 9.10228289e-06]\n",
    "    for t, g in dirs:\n",
    "        gains.append(g)\n",
    "        mins = []\n",
    "        means = []\n",
    "        maxes = []\n",
    "        for d in dirs[(t,g)]:\n",
    "            try:\n",
    "                with open(join(d, 'results'), 'r') as f:\n",
    "                    lcount, lmean, lmin, lmax, _ = f.readlines()\n",
    "                    mins.append(float(lmin.strip().split()[-1]))\n",
    "                    means.append(float(lmean.strip().split()[-1]))\n",
    "                    maxes.append(float(lmax.strip().split()[-1]))\n",
    "            except:\n",
    "                pass\n",
    "        N = len(means)\n",
    "        if g > 13: N = N // 2\n",
    "        n10 = int(np.ceil(N * 0.1))\n",
    "        n50 = int(np.ceil(N * 0.5))\n",
    "        n90 = int(np.ceil(N * 0.9))\n",
    "        ber10.append(sorted(means)[n10])\n",
    "        ber50.append(sorted(means)[n50])\n",
    "        ber90.append(sorted(means)[n90])\n",
    "    # Convert to ndarrays and sort\n",
    "    gains = np.array(gains)\n",
    "    ber10 = np.array(ber10)\n",
    "    ber50 = np.array(ber50)\n",
    "    ber90 = np.array(ber90)\n",
    "    isort = np.argsort(gains)\n",
    "    gains = gains[isort]\n",
    "    ber10 = ber10[isort]\n",
    "    ber50 = ber50[isort]\n",
    "    ber90 = ber90[isort]\n",
    "    plt.plot(gains, baseline, label=\"Baseline\", marker='o')\n",
    "    plt.plot(gains, ber10, label=\"10%\", marker='o')\n",
    "    plt.plot(gains, ber50, label=\"50%\", marker='o')\n",
    "    plt.plot(gains, ber90, label=\"90%\", marker='o')\n",
    "    plt.title(\"{} preamble {}, BER vs Gain(TODO SNR)\".format(\n",
    "                sharedstr, agentstr))\n",
    "    plt.legend()\n",
    "    plt.yscale(\"log\")\n",
    "    plt.show()\n",
    "    return gains, ber10, ber50, ber90\n",
    "\n",
    "for sharedstr in (\"shared\", \"private\"):\n",
    "    for agentstr in (\"nn\", \"nc\"):\n",
    "        print(' '.join([\"=====\", sharedstr.upper(), \n",
    "                       agentstr.upper(), \"=====\"]))\n",
    "        gains, ber10, ber50, ber90 = do_snr_ber(durations[sharedstr][agentstr],\n",
    "                   sharedstr, agentstr)\n",
    "        out = {'simulation_equivalent_snr': [4.2, 8.4, 10.4, 12, 13, 15],\n",
    "               'empirical_snr': [6.52637972, 10.56335866, 12.80612474, 14.15178439, 15.49744404],\n",
    "               'empirical_baseline': [1.04085880e-01, 9.88240808e-03, 7.21992638e-04, 9.59187110e-05, 9.10228289e-06],\n",
    "               '10th-percentile': list(ber10),\n",
    "               '50th-percentile': list(ber50),\n",
    "               '90th-percentile': list(ber90),\n",
    "              }\n",
    "        with open(\"{}-{}-ber_curve.pkl\".format(sharedstr, agentstr), \"wb\") as f:\n",
    "            pickle.dump(out, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
